#include <fstream>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "core/configure/include/configure_parser.h"
#include "core/configure/inferencer_configure.pb.h"
#include "core/predictor/common/utils.h"
#include "core/predictor/framework/infer.h"
#include "paddle_inference_api.h"  // NOLINT

namespace baidu {
namespace paddle_serving {
namespace inference {

using paddle_infer::Config;
using paddle_infer::PrecisionType;
using paddle_infer::Predictor;
using paddle_infer::Tensor;
using paddle_infer::CreatePredictor;

// Command-line flags declared here and read further down in this file as
// FLAGS_gpuid, FLAGS_precision and FLAGS_use_calib.
// NOTE(review): presumably gflags declarations whose definitions live in
// another translation unit -- confirm.
//   gpuid     - passed as the second argument of Config::EnableUseGpu.
//   precision - textual precision name handed to GetPrecision().
//   use_calib - passed as the last argument of Config::EnableTensorRtEngine.
DECLARE_int32(gpuid);
DECLARE_string(precision);
DECLARE_bool(use_calib);

// Passed as the second and third arguments of Config::EnableTensorRtEngine
// in PaddleInferenceEngine::create().
static const int max_batch = 32;
static const int min_subgraph_size = 3;
// File-level precision setting. It is overwritten with
// GetPrecision(FLAGS_precision) on every PaddleInferenceEngine::create()
// call and then read when enabling the TensorRT / Lite / MKLDNN options.
static PrecisionType precision_type;

// Maps a textual precision name to a paddle_infer::PrecisionType.
//
// The input is first passed through predictor::ToLower and the result is
// compared against the supported names:
//   "fp32" -> PrecisionType::kFloat32
//   "int8" -> PrecisionType::kInt8
//   "fp16" -> PrecisionType::kHalf
// Any other value also yields PrecisionType::kFloat32.
PrecisionType GetPrecision(const std::string& precision_data) {
  const std::string lowered = predictor::ToLower(precision_data);
  if (lowered == "int8") {
    return PrecisionType::kInt8;
  }
  if (lowered == "fp16") {
    return PrecisionType::kHalf;
  }
  // "fp32" and every unrecognised name resolve to single precision.
  return PrecisionType::kFloat32;
}

// Engine Base
class EngineCore {
 public:
  virtual ~EngineCore() {}
  virtual std::vector<std::string> GetInputNames() {
    return _predictor->GetInputNames();
  }

  virtual std::unique_ptr<Tensor> GetInputHandle(const std::string& name) {
    return _predictor->GetInputHandle(name);
  }

  virtual std::vector<std::string> GetOutputNames() {
    return _predictor->GetOutputNames();
  }

  virtual std::unique_ptr<Tensor> GetOutputHandle(const std::string& name) {
    return _predictor->GetOutputHandle(name);
  }

  virtual bool Run() {
    if (!_predictor->Run()) {
      LOG(ERROR) << "Failed call Run with paddle predictor";
      return false;
    }
    return true;
  }

  virtual int create(const configure::EngineDesc& conf) = 0;

  virtual int clone(void* predictor) {
    if (predictor == NULL) {
      LOG(ERROR) << "origin paddle Predictor is null.";
      return -1;
    }
    Predictor* prep = static_cast<Predictor*>(predictor);
    _predictor = prep->Clone();
    if (_predictor.get() == NULL) {
      LOG(ERROR) << "fail to clone paddle predictor: " << predictor;
      return -1;
    }
    return 0;
  }

  virtual void* get() { return _predictor.get(); }

 protected:
  std::shared_ptr<Predictor> _predictor;
};

// Paddle Inference Engine
//
// EngineCore implementation that translates a configure::EngineDesc into a
// paddle_infer::Config and creates the predictor from it.
class PaddleInferenceEngine : public EngineCore {
 public:
  // Builds _predictor according to engine_conf.
  //
  // Returns 0 on success. Returns -1 when engine_conf.model_dir() is not
  // accessible or when CreatePredictor() hands back a null predictor.
  // Side effect: overwrites the file-level precision_type with
  // GetPrecision(FLAGS_precision).
  int create(const configure::EngineDesc& engine_conf) {
    const std::string model_path = engine_conf.model_dir();
    if (access(model_path.c_str(), F_OK) == -1) {
      LOG(ERROR) << "create paddle predictor failed, path not exits: "
                 << model_path;
      return -1;
    }

    // A feature counts as requested only when its field is both present and
    // true in the engine description.
    const bool encrypted =
        engine_conf.has_encrypted_model() && engine_conf.encrypted_model();
    const bool gpu_on = engine_conf.has_use_gpu() && engine_conf.use_gpu();
    const bool trt_on = engine_conf.has_use_trt() && engine_conf.use_trt();
    const bool lite_on = engine_conf.has_use_lite() && engine_conf.use_lite();
    const bool xpu_on = engine_conf.has_use_xpu() && engine_conf.use_xpu();
    const bool memory_optim_on = engine_conf.has_enable_memory_optimization() &&
                                 engine_conf.enable_memory_optimization();

    // 2000MB GPU memory
    const int gpu_init_memory_mb = 2000;

    Config config;
    // todo, auto config(zhangjun)
    if (encrypted) {
      // decrypt model: read the encrypted model, its parameters and the key
      // from model_path, decrypt both buffers and load them from memory.
      std::string cipher_model, cipher_params, key;
      predictor::ReadBinaryFile(model_path + "/encrypt_model", &cipher_model);
      predictor::ReadBinaryFile(model_path + "/encrypt_params", &cipher_params);
      predictor::ReadBinaryFile(model_path + "/key", &key);

      auto cipher = paddle::MakeCipher("");
      std::string plain_model = cipher->Decrypt(cipher_model, key);
      std::string plain_params = cipher->Decrypt(cipher_params, key);
      config.SetModelBuffer(&plain_model[0],
                            plain_model.size(),
                            &plain_params[0],
                            plain_params.size());
    } else if (engine_conf.has_combined_model() &&
               !engine_conf.combined_model()) {
      // combined_model explicitly set to false: model_path is handed to
      // SetModel as-is.
      config.SetModel(model_path);
    } else {
      // combined_model is true or not set at all: use the __params__ /
      // __model__ file pair inside model_path.
      config.SetParamsFile(model_path + "/__params__");
      config.SetProgFile(model_path + "/__model__");
    }

    config.SwitchSpecifyInputNames(true);
    config.SetCpuMathLibraryNumThreads(1);
    if (gpu_on) {
      config.EnableUseGpu(gpu_init_memory_mb, FLAGS_gpuid);
    }
    precision_type = GetPrecision(FLAGS_precision);

    if (trt_on) {
      // TensorRT is requested: enable the GPU here unless it was already
      // enabled above.
      if (!gpu_on) {
        config.EnableUseGpu(gpu_init_memory_mb, FLAGS_gpuid);
      }
      config.EnableTensorRtEngine(1 << 20,
                                  max_batch,
                                  min_subgraph_size,
                                  precision_type,
                                  false,
                                  FLAGS_use_calib);
      LOG(INFO) << "create TensorRT predictor";
    }

    if (lite_on) {
      config.EnableLiteEngine(precision_type, true);
    }

    // The MKLDNN precision switches are applied only when use_lite and
    // use_gpu are either both absent, or both present and both false.
    const bool lite_gpu_both_absent =
        !engine_conf.has_use_lite() && !engine_conf.has_use_gpu();
    const bool lite_gpu_both_false =
        engine_conf.has_use_lite() && !engine_conf.use_lite() &&
        engine_conf.has_use_gpu() && !engine_conf.use_gpu();
    if (lite_gpu_both_absent || lite_gpu_both_false) {
      switch (precision_type) {
        case PrecisionType::kInt8:
          config.EnableMkldnnQuantizer();
          break;
        case PrecisionType::kHalf:
          config.EnableMkldnnBfloat16();
          break;
        default:
          break;
      }
    }

    if (xpu_on) {
      // 2 MB l3 cache
      config.EnableXpu(2 * 1024 * 1024);
    }

    // IR optimization stays on unless the description explicitly disables it.
    const bool ir_optim_disabled = engine_conf.has_enable_ir_optimization() &&
                                   !engine_conf.enable_ir_optimization();
    config.SwitchIrOptim(!ir_optim_disabled);

    if (memory_optim_on) {
      config.EnableMemoryOptim();
    }

    // Predictor creation happens while holding the global create mutex; the
    // lock is released when this function returns.
    predictor::AutoLock lock(predictor::GlobalCreateMutex::instance());
    _predictor = CreatePredictor(config);
    if (!_predictor) {
      LOG(ERROR) << "create paddle predictor failed, path: " << model_path;
      return -1;
    }

    VLOG(2) << "create paddle predictor sucess, path: " << model_path;
    return 0;
  }
};

}  // namespace inference
}  // namespace paddle_serving
}  // namespace baidu